{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) University of Strasbourg. All Rights Reserved.\n",
    "# ORPose-Depth\n",
    "**Human Pose Estimation on Privacy-Preserving Low-Resolution Depth Images (MICCAI-2019)**\n",
    "\n",
    "_Vinkle Srivastav, Afshin Gangi, Nicolas Padoy_\n",
    "\n",
    "[![arXiv](https://img.shields.io/badge/arxiv-2007.08340-red)](https://arxiv.org/abs/2007.08340) \n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CAMMA-public/ORPose-Depth/blob/master/orpose_depth_demo.ipynb)\n",
    "\n",
    "------------------------------\n",
    "\n",
    "**This demo notebook contains the inference and evaluation scripts for the following models: DepthPose_80x60 and DepthPose_64x48**\n",
    "\n",
    "**The code below is needed only for the Colab demo. Please make sure to enable the \"GPU\" runtime via Edit -> Notebook settings.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the code, models and the data\n",
    "!git clone https://github.com/CAMMA-public/ORPose-Depth.git\n",
    "%cd ORPose-Depth\n",
    "\n",
    "!wget https://s3.unistra.fr/camma_public/github/DepthPose/models.zip\n",
    "!wget https://s3.unistra.fr/camma_public/github/DepthPose/data.zip\n",
    "!unzip -q models.zip\n",
    "!unzip -q data.zip\n",
    "!rm models.zip data.zip\n",
    "\n",
    "# Use the %pip magic (not !pip) so the package is installed into the kernel's own environment\n",
    "%pip install yacs\n",
    "\n",
    "# Build the NMS library\n",
    "%cd lib\n",
    "!make\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import os, sys\n",
    "import subprocess\n",
    "import glob\n",
    "import torch\n",
    "import torchvision\n",
    "import random\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "\n",
    "# to run the video\n",
    "from IPython.display import HTML, display\n",
    "from base64 import b64encode\n",
    "\n",
    "import cv2\n",
    "import numpy as np\n",
    "lib_dir = os.path.join(os.getcwd(),\"lib\") \n",
    "if lib_dir not in sys.path:\n",
    "    sys.path.insert(0, lib_dir)\n",
    "\n",
    "# add the local library code\n",
    "from core.inference import get_poses\n",
    "from models.depthpose_x8 import get_model\n",
    "from dataset.mvor import MVORDatasetTest\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.vis_utils import VisUtils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose the model type and set the paths to the data and model directories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose model type\n",
    "MODEL_TYPE = \"x10\" # or x8\n",
    "\n",
    "# paths and parameters\n",
    "DATA_DIR = \"data\"\n",
    "MODELS_DIR = \"models\"\n",
    "# passed unchanged to get_poses(); NOTE(review): presumably detection/association thresholds -- confirm in core.inference\n",
    "PARAMS = {\"thre1\": 0.1, \"thre2\": 0.05, \"thre3\": 0.5}\n",
    "USE_GPU = True\n",
    "\n",
    "# fall back to the CPU when no GPU is available or when USE_GPU is switched off\n",
    "if torch.cuda.is_available() and USE_GPU:\n",
    "    DEVICE = torch.device(\"cuda\")\n",
    "else:\n",
    "    DEVICE = torch.device(\"cpu\")\n",
    "\n",
    "print(\"Using device:\", DEVICE)\n",
    "\n",
    "# the depth images of every model variant live next to the color images under images/MVOR/\n",
    "COLOR_IMAGES_DIR = os.path.join(DATA_DIR, \"images/MVOR/LR_x5_color\")\n",
    "if MODEL_TYPE == \"x8\":\n",
    "    MODEL_PATH = os.path.join(MODELS_DIR, \"depthpose_80x60.pth\")\n",
    "    DEPTH_IMAGES_DIR = os.path.join(DATA_DIR, \"images/MVOR/LR_x8\")\n",
    "    PARAMS[\"scale\"] = 8\n",
    "elif MODEL_TYPE == \"x10\":\n",
    "    MODEL_PATH = os.path.join(MODELS_DIR, \"depthpose_64x48.pth\")\n",
    "    DEPTH_IMAGES_DIR = os.path.join(DATA_DIR, \"images/MVOR/LR_x10\")\n",
    "    PARAMS[\"scale\"] = 10\n",
    "else:\n",
    "    # stop here: continuing would leave MODEL_PATH and DEPTH_IMAGES_DIR undefined for the cells below\n",
    "    raise ValueError(\"please select proper model! MODEL_TYPE must be 'x8' or 'x10', got {!r}\".format(MODEL_TYPE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_files():\n",
    "    \"\"\" Read the color and depth paths\n",
    "    The file names are taken from the \"images\" entries of\n",
    "    annotations/mvor_eval_depth_2018.json inside DATA_DIR.\n",
    "    Returns:\n",
    "        list of (depth_path, color_path) tuples\n",
    "    \"\"\"\n",
    "    annotation_path = os.path.join(DATA_DIR, \"annotations/mvor_eval_depth_2018.json\")\n",
    "    # read the annotations inside a context manager so the file handle is closed after parsing\n",
    "    with open(annotation_path) as annotation_file:\n",
    "        file_names = [image[\"file_name\"] for image in json.load(annotation_file)[\"images\"]]\n",
    "    # the color file name is derived by replacing \"depth\" with \"color\" in the depth file name\n",
    "    files = [\n",
    "        (os.path.join(DEPTH_IMAGES_DIR, f), os.path.join(COLOR_IMAGES_DIR, f.replace(\"depth\", \"color\")))\n",
    "        for f in file_names\n",
    "    ]\n",
    "    print(\"Found {} depth and color files\".format(len(files)))\n",
    "    return files\n",
    "\n",
    "def progress_bar(value, max=100):\n",
    "    \"\"\" A HTML helper function to display the progress bar\n",
    "    Args:\n",
    "        value ([int]): [current progress bar value]\n",
    "        max (int, optional): [maximum value]. Defaults to 100.\n",
    "    Returns:\n",
    "        [IPython.display.HTML]: [HTML object wrapping a <progress> element]\n",
    "    \"\"\"\n",
    "    # HTML attributes are separated by whitespace only, so there is no comma after the max attribute\n",
    "    return HTML(\"\"\"\n",
    "        <progress\n",
    "            value='{value}'\n",
    "            max='{max}'\n",
    "            style='width: 100%'\n",
    "        >\n",
    "            {value}\n",
    "        </progress>\n",
    "    \"\"\".format(value=value, max=max))\n",
    "\n",
    "def bgr2rgb(im):\n",
    "    \"\"\"[swap the first and third channel of a 3-channel OpenCV image (BGR -> RGB)]\n",
    "    Args:\n",
    "        im ([numpy.ndarray]): [input image in BGR format]\n",
    "    Returns:\n",
    "        [numpy.ndarray]: [output image in RGB format]\n",
    "    \"\"\"\n",
    "    # unpacking into exactly three names keeps the requirement of a 3-channel input\n",
    "    blue, green, red = cv2.split(im)\n",
    "    rgb_channels = [red, green, blue]\n",
    "    return cv2.merge(rgb_channels)\n",
    "\n",
    "def add_border(im, bordersize=25, mean=0):\n",
    "    \"\"\"\n",
    "    [pad the image with a constant-colored frame on all four sides]\n",
    "    Args:\n",
    "        im ([numpy.ndarray]): [input image]\n",
    "        bordersize (int, optional): [border width applied to top, bottom, left and right]. Defaults to 25.\n",
    "        mean (int, optional): [value used for each of the three border color components]. Defaults to 0.\n",
    "    Returns:\n",
    "        [numpy.ndarray]: [output image]\n",
    "    \"\"\"\n",
    "    border_color = [mean, mean, mean]\n",
    "    # positional order expected by OpenCV: top, bottom, left, right, borderType\n",
    "    return cv2.copyMakeBorder(\n",
    "        im, bordersize, bordersize, bordersize, bordersize, cv2.BORDER_CONSTANT, value=border_color\n",
    "    )\n",
    "\n",
    "\n",
    "def images_to_video(img_folder, output_vid_file, fps=20):\n",
    "    \"\"\"[convert png images to video using ffmpeg]\n",
    "    Args:\n",
    "        img_folder ([str]): [path to images; frames are read with the pattern %06d.png]\n",
    "        output_vid_file ([str]): [Name of the output video file name]\n",
    "        fps (int, optional): [frame rate passed to ffmpeg via -framerate]. Defaults to 20.\n",
    "    Raises:\n",
    "        RuntimeError: [if ffmpeg exits with a non-zero return code]\n",
    "    \"\"\"\n",
    "    os.makedirs(img_folder, exist_ok=True)\n",
    "    command = [\n",
    "        \"ffmpeg\",\n",
    "        \"-y\",\n",
    "        \"-framerate\",\n",
    "        str(fps),\n",
    "        \"-threads\",\n",
    "        \"16\",\n",
    "        \"-i\",\n",
    "        f\"{img_folder}/%06d.png\",\n",
    "        \"-profile:v\",\n",
    "        \"baseline\",\n",
    "        \"-level\",\n",
    "        \"3.0\",\n",
    "        \"-c:v\",\n",
    "        \"libx264\",\n",
    "        \"-pix_fmt\",\n",
    "        \"yuv420p\",\n",
    "        \"-an\",\n",
    "        \"-v\",\n",
    "        \"error\",\n",
    "        output_vid_file,\n",
    "    ]\n",
    "    print(f'\\nRunning \"{\" \".join(command)}\"')\n",
    "    return_code = subprocess.call(command)\n",
    "    # do not report success when ffmpeg failed; otherwise the error only surfaces later when the video is read\n",
    "    if return_code != 0:\n",
    "        raise RuntimeError(f\"ffmpeg failed with exit code {return_code}\")\n",
    "    print(\"\\nVideo generation finished\")\n",
    "\n",
    "# Build the list of (depth_path, color_path) pairs returned by get_files()\n",
    "files = get_files()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the model and load the weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = get_model()\n",
    "# map_location remaps the stored tensors to DEVICE, so a checkpoint saved on a GPU also loads on a CPU-only machine\n",
    "model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE), strict=True)\n",
    "model = model.to(DEVICE)\n",
    "# print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference on the low-resolution depth image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# initialize the visualization rendering object\n",
    "vis_utils = VisUtils() \n",
    "\n",
    "# Read a random low-resolution depth image and corresponding color image\n",
    "path_depth, path_color = random.choice(files)\n",
    "\n",
    "# Convert depth image to PyTorch tensor (with a leading batch dimension) and display the size\n",
    "im_depth = torchvision.transforms.functional.to_tensor(Image.open(path_depth)).unsqueeze_(0)\n",
    "im_depth = im_depth.to(DEVICE)\n",
    "print(\"size of the input depth image: {}x{}\".format(im_depth.shape[-1], im_depth.shape[-2]))\n",
    "\n",
    "# Read the corresponding color image, convert it from BGR to RGB and resize it to (vis_utils.width, vis_utils.height)\n",
    "im_color = cv2.resize(bgr2rgb(cv2.imread(path_color)), (vis_utils.width, vis_utils.height))\n",
    "\n",
    "# Get the keypoints by running the inference code\n",
    "keypoints = get_poses(model, im_depth, PARAMS)\n",
    "\n",
    "# Move the depth image back to the CPU as a numpy array and render the keypoints on it\n",
    "im_depth = im_depth.cpu().squeeze().numpy()\n",
    "im_depth = vis_utils.render(im_depth, keypoints)    \n",
    "\n",
    "# Plot the results\n",
    "fig = plt.figure(figsize=(20,10))\n",
    "fig.add_subplot(1, 2, 1); plt.imshow(im_depth); plt.title(\"Output depth\"); plt.axis(\"off\")\n",
    "fig.add_subplot(1, 2, 2); plt.imshow(im_color); plt.title(\"Corresponding color image\"); plt.axis(\"off\")\n",
    "plt.show()\n",
    "\n",
    "# Run the cell again to check the result on a different random image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference demo on the video frames"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set the paths for the video frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the paths (relative to DATA_DIR, like the other data paths of this notebook)\n",
    "IMGS_PATH_DEPTH = os.path.join(DATA_DIR, \"images/MVOR_seq/LR_x10\")\n",
    "IMGS_PATH_COLOR = os.path.join(DATA_DIR, \"images/MVOR_seq/LR_x10_color\")\n",
    "\n",
    "\n",
    "# Name of the output video file\n",
    "OUTPUT_DIR = \"output\"\n",
    "OUTPUT_VID_NAME = os.path.join(OUTPUT_DIR, \"output.mp4\")\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "# Read the paths and put it in a list; depth and color frames are paired by their sorted order\n",
    "files_depth = sorted(glob.glob(IMGS_PATH_DEPTH + \"/*.png\"))\n",
    "files_color = sorted(glob.glob(IMGS_PATH_COLOR + \"/*.png\"))\n",
    "# fail early with a clear message instead of producing an empty video later on\n",
    "if not files_depth or not files_color:\n",
    "    raise FileNotFoundError(\"No png frames found in {} and/or {}\".format(IMGS_PATH_DEPTH, IMGS_PATH_COLOR))\n",
    "# zip() silently truncates to the shorter list, so warn when the two folders are out of sync\n",
    "if len(files_depth) != len(files_color):\n",
    "    print(\"Warning: {} depth frames but {} color frames; the extra frames are ignored\".format(len(files_depth), len(files_color)))\n",
    "files = list(zip(files_depth, files_color))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference and rendering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(\"Running inference on video frames\")\n",
    "num_frames = len(files)\n",
    "# Initialize the progress-bar at 0: no frame has been processed yet\n",
    "out = display(progress_bar(0, num_frames), display_id=True)\n",
    "\n",
    "# caption settings shared by both panels\n",
    "font = cv2.FONT_HERSHEY_SIMPLEX\n",
    "org_depth = (200, 25) # caption position on the depth panel\n",
    "org_color = (250, 25) # caption position on the color panel\n",
    "fontScale = 0.5\n",
    "color = (255, 255, 255) # white\n",
    "thickness = 1\n",
    "\n",
    "# Run the inference on the low-resolution depth frames\n",
    "for index, (path_depth, path_color) in enumerate(files):\n",
    "    im_depth = torchvision.transforms.functional.to_tensor(Image.open(path_depth)).unsqueeze_(0)\n",
    "    im_depth = im_depth.to(DEVICE)\n",
    "    im_color = cv2.imread(path_color)\n",
    "    # NOTE(review): the color frame is first shrunk to 80x60 and then enlarged again, presumably to show it at low resolution -- confirm\n",
    "    im_color = cv2.resize(im_color, (80, 60))\n",
    "    im_color = cv2.resize(im_color, (vis_utils.width, vis_utils.height))\n",
    "    keypoints = get_poses(model, im_depth, PARAMS)\n",
    "    im_depth = im_depth.cpu().squeeze().numpy()\n",
    "\n",
    "    im_depth = vis_utils.render(im_depth, keypoints, bgr2rgb=False)\n",
    "    im_depth = add_border(im_depth, bordersize=35)\n",
    "    im_color = add_border(im_color, bordersize=35)\n",
    "    im_depth = cv2.putText(im_depth, \"Output on low-resolution depth images\", org_depth, font, fontScale, color, thickness, cv2.LINE_AA)\n",
    "    im_color = cv2.putText(im_color, \"Corresponding color images\", org_color, font, fontScale, color, thickness, cv2.LINE_AA)\n",
    "    # place the depth and color panels side by side and store the frame as <index>.png for ffmpeg\n",
    "    vis = np.concatenate((im_depth, im_color), axis=1)\n",
    "    cv2.imwrite(os.path.join(OUTPUT_DIR, f\"{index:06d}\" + \".png\"), vis)\n",
    "    out.update(progress_bar(index + 1, num_frames))\n",
    "\n",
    "# Convert the rendered images to video\n",
    "images_to_video(OUTPUT_DIR, OUTPUT_VID_NAME, fps=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show the output video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the video inside a context manager so the file handle is closed right away\n",
    "with open(OUTPUT_VID_NAME, \"rb\") as video_file:\n",
    "    mp4 = video_file.read()\n",
    "# embed the video as a base64 data URL inside the HTML snippet displayed below\n",
    "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
    "HTML(\"\"\" <video width=800 controls>\n",
    "         <source src=\"%s\" type=\"video/mp4\">\n",
    "         </video> \n",
    "     \"\"\" % data_url\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation code for **DepthPose_64x48** and **DepthPose_80x60** on the MVOR dataset. \n",
    "\n",
    "**Each model takes approximately 40 minutes for the evaluation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To run the evaluation for DepthPose_64x48 model\n",
    "!python tools/eval_mvor.py --config_file experiments/mvor/DepthPose_64x48.yaml                     "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To run the evaluation for DepthPose_80x60 model\n",
    "!python tools/eval_mvor.py --config_file experiments/mvor/DepthPose_80x60.yaml"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
